import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (Layer, Input, LayerNormalization,
                                     Dense, Dropout, Conv2D)
from tensorflow.keras.activations import gelu

from nets.custom_functions import drop_path, window_partition, window_reverse
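# The helpers imported from nets.custom_functions are assumed to behave as in
# the reference Swin Transformer implementation (an assumption; they are
# defined elsewhere in this repo):
#   drop_path(x, drop_prob)              -> stochastic depth: zeroes whole samples
#                                           at random, rescales the rest by 1/(1-p)
#   window_partition(x, window_size)     -> (B, H, W, C) -> (num_windows*B, ws, ws, C)
#   window_reverse(windows, ws, H, W, C) -> (num_windows*B, ws, ws, C) -> (B, H, W, C)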
 
 
class MLPLayer(Layer):
    """Feed-forward block used inside each Swin block:
    Dense -> GELU -> Dropout -> Dense -> Dropout.
    The output width is inferred from the input in build()."""
    def __init__(self, hidden_features=None, drop_rate=0., **kwargs):
        super(MLPLayer, self).__init__(**kwargs)

        self.hidden_features = hidden_features
        self.drop_rate = drop_rate

        self.fc1 = Dense(self.hidden_features)
        self.drop = Dropout(self.drop_rate)
    
    def get_config(self):
        config = super(MLPLayer, self).get_config()
        # out_features is derived from the input shape in build(), so it is not
        # serialized (__init__ does not accept it).
        config.update({"hidden_features": self.hidden_features,
                       "drop_rate": self.drop_rate})
        return config
    
    def build(self, input_shape):
        # Project back to the input channel count.
        self.out_features = input_shape[-1]
        self.fc2 = Dense(self.out_features)
        self.built = True
 
    def call(self,inputs):
        x = self.fc1(inputs)
        x = gelu(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
 
        return x
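
# Minimal shape sketch for MLPLayer (illustration only, in the spirit of the
# commented demo at the bottom of this file):
#   x = tf.zeros([2, 49, 96])             # (num_windows*B, tokens, C)
#   y = MLPLayer(hidden_features=384)(x)  # expand to 4*C, project back to C
#   assert y.shape == (2, 49, 96)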
 
class WindowAttentionLayer(Layer):
    """Window-based multi-head self-attention (W-MSA) with a learned
    relative position bias, as in Swin Transformer."""
    def __init__(self, dim, window_size, num_heads, qkv_bias=True,
                 qk_scale=None, attn_drop_rate=0.,
                 proj_drop_rate=0., **kwargs):
        super(WindowAttentionLayer, self).__init__(**kwargs)

        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qk_scale = qk_scale
        self.scale = qk_scale or (self.head_dim ** (-0.5))
        self.qkv_bias = qkv_bias
        self.attn_drop_rate = attn_drop_rate
        self.proj_drop_rate = proj_drop_rate

        self.qkv = Dense(self.dim * 3, use_bias=self.qkv_bias)
        self.attn_drop = Dropout(self.attn_drop_rate)
        self.proj = Dense(self.dim)
        self.proj_drop = Dropout(self.proj_drop_rate)
 
 
    def get_config(self):
        config = super(WindowAttentionLayer, self).get_config()
        # Only constructor arguments are serialized; head_dim and scale are
        # re-derived in __init__.
        config.update({"dim": self.dim,
                       "window_size": self.window_size,
                       "num_heads": self.num_heads,
                       "qkv_bias": self.qkv_bias,
                       "qk_scale": self.qk_scale,
                       "attn_drop_rate": self.attn_drop_rate,
                       "proj_drop_rate": self.proj_drop_rate})
        return config
    
    def build(self, input_shape):
        # Learnable bias table: one entry per relative offset per head;
        # there are (2*Wh-1)*(2*Ww-1) possible offsets.
        self.relative_position_bias_table = self.add_weight(
            shape=[(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
                   self.num_heads],
            initializer=tf.initializers.Zeros(),
            trainable=True,
            name='relative_position'
        )
 
        # Map each pair of positions inside a window to a single index into the
        # bias table (standard Swin bookkeeping).
        coords_h = np.arange(self.window_size[0])  # e.g. 0..6 for a 7x7 window
        coords_w = np.arange(self.window_size[1])
        coords = np.stack(np.meshgrid(coords_h, coords_w, indexing='ij'))  # (2, Wh, Ww)
        coords_flatten = coords.reshape(2, -1)  # (2, Wh*Ww)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.transpose([1, 2, 0])  # (Wh*Ww, Wh*Ww, 2)
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift offsets to start at 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1).astype(np.int64)
        self.relative_position_index = tf.Variable(
            initial_value=tf.convert_to_tensor(relative_position_index),
            trainable=False,
            name='relative_position_index'
        )
        self.built = True
 
    def call(self, x, mask=None):
        # x: (num_windows*B, N, C) with N = window_size[0]*window_size[1]
        _, N, C = x.shape.as_list()
        qkv = self.qkv(x)
        q, k, v = tf.split(qkv, 3, axis=-1)  # each (num_windows*B, N, C)
        # reshape to (num_windows*B, num_heads, N, head_dim)
        q = tf.transpose(tf.reshape(q, shape=[-1, N, self.num_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(k, shape=[-1, N, self.num_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(v, shape=[-1, N, self.num_heads, self.head_dim]), [0, 2, 1, 3])

        q = self.scale * q
        # attn: (num_windows*B, num_heads, N, N)
        attn = tf.matmul(q, k, transpose_b=True)
        relative_position_bias = tf.gather(
            self.relative_position_bias_table,
            tf.reshape(self.relative_position_index, shape=[-1])
        )
        relative_position_bias = tf.reshape(relative_position_bias,
                    shape=[self.window_size[0] * self.window_size[1],
                           self.window_size[0] * self.window_size[1],
                           -1])
        # (num_heads, N, N)
        relative_position_bias = tf.transpose(relative_position_bias, [2, 0, 1])
        attn = attn + tf.expand_dims(relative_position_bias, axis=0)
        if mask is not None:
            # mask: (num_windows, N, N); broadcast over batch and heads.
            mask = tf.convert_to_tensor(mask)
            nW = mask.shape[0]

            attn = tf.reshape(attn, shape=[-1, nW, self.num_heads, N, N]) + \
                    tf.cast(tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0),
                    attn.dtype)
            attn = tf.reshape(attn, shape=[-1, self.num_heads, N, N])

        attn = tf.nn.softmax(attn, axis=-1)
        
        attn = self.attn_drop(attn)
        # (num_windows*B, N, num_heads, head_dim)
        x = tf.transpose((attn @ v), [0, 2, 1, 3])
        # (num_windows*B, N, C)
        x = tf.reshape(x, shape=[-1, N, C])
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
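
# Worked example of the relative-position bookkeeping above, for a 2x2 window
# (an intuition sketch, not executed by the model):
#   the bias table has (2*2-1)*(2*2-1) = 9 rows, one per distinct (dh, dw) offset.
#   With tokens ordered (0,0),(0,1),(1,0),(1,1), the index for pair (i, j) is
#   (h_i - h_j + 1) * 3 + (w_i - w_j + 1), giving
#       [[4, 3, 1, 0],
#        [5, 4, 2, 1],
#        [7, 6, 4, 3],
#        [8, 7, 5, 4]]
#   so index 4 (offset (0,0)) sits on the diagonal, and every token pair with
#   the same spatial offset shares one learned bias per head.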
    
class DropPathLayer(Layer):
    """Wraps the drop_path helper (stochastic depth) as a Keras layer."""
    def __init__(self, drop_prob=None, **kwargs):
        super(DropPathLayer, self).__init__(**kwargs)
        self.drop_prob = drop_prob

    def call(self, x):
        return drop_path(x, self.drop_prob)
    
    def get_config(self):
        config = super(DropPathLayer,self).get_config()
        config.update({"drop_prob":self.drop_prob})
        return config
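
# For reference, a typical drop_path (stochastic depth) implementation looks
# like the sketch below. This is an assumption about what
# nets.custom_functions.drop_path does, shown only to document the expected
# semantics; the imported helper is what actually runs.
#
#   def drop_path(x, drop_prob):
#       if drop_prob == 0. or drop_prob is None:
#           return x
#       keep_prob = 1.0 - drop_prob
#       # one Bernoulli draw per sample: drop the whole residual branch
#       shape = [tf.shape(x)[0]] + [1] * (len(x.shape) - 1)
#       random_tensor = keep_prob + tf.random.uniform(shape, dtype=x.dtype)
#       binary_tensor = tf.floor(random_tensor)
#       return (x / keep_prob) * binary_tensor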
 
 
class SwinTransformerBlockLayer(Layer):
    """One Swin Transformer block: (shifted) window attention + MLP,
    each with a residual connection and pre-LayerNorm."""
    def __init__(self, dim, input_resolution, num_heads, window_size=7,
                 shift_size=0, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_prob=0.,
                 **kwargs):
        super(SwinTransformerBlockLayer, self).__init__(**kwargs)
 
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
 
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_prob = drop_path_prob
 
 
        # If the feature map is no larger than the window, shrink the window
        # to the feature map and disable shifting.
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)

        assert 0 <= self.shift_size < self.window_size, \
            'shift_size must be in [0, window_size)'
 
        self.norm1 = LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttentionLayer(self.dim,(self.window_size,self.window_size),
                                        self.num_heads,self.qkv_bias,self.qk_scale,
                                        self.attn_drop_rate,self.drop_rate)
        self.drop_path = DropPathLayer(self.drop_path_prob)
        self.norm2 = LayerNormalization(epsilon=1e-5)
        mlp_hidden_dim = int(dim*self.mlp_ratio)
        self.mlp = MLPLayer(hidden_features=mlp_hidden_dim,
                            drop_rate=self.drop_rate)
        
    def build(self, input_shape):
        if self.shift_size > 0:
            # Attention mask for shifted windows: tokens that the cyclic shift
            # brings together from different image regions must not attend to
            # each other.
            H, W = self.input_resolution
            # float32 so the scalar branches of tf.where below match dtypes
            img_mask = np.zeros([1, H, W, 1], dtype=np.float32)
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            img_mask = tf.convert_to_tensor(img_mask)
            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = tf.reshape(mask_windows, shape=[
                -1, self.window_size * self.window_size
            ])
            # (nW, 1, N) - (nW, N, 1) -> (nW, N, N): nonzero where region ids differ
            attn_mask = tf.expand_dims(mask_windows, axis=1) - tf.expand_dims(mask_windows, axis=2)

            attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
            attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)
            self.attn_mask = tf.Variable(initial_value=attn_mask,
                            trainable=False, name='attn_mask')
        else:
            self.attn_mask = None
        self.built = True
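
    # Mask sketch for H = W = 4, window_size = 2, shift_size = 1
    # (illustration only). Region ids written into img_mask:
    #     0 0 1 2
    #     0 0 1 2
    #     3 3 4 5
    #     6 6 7 8
    # After window_partition, any token pair within one window whose ids differ
    # gets -100 added to its attention logit, which softmax turns into ~0 weight.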
 
 
    def get_config(self):
        config = super(SwinTransformerBlockLayer,self).get_config()
        config.update({"dim":self.dim,
                       "input_resolution":self.input_resolution,
                       "num_heads":self.num_heads,
                       "window_size":self.window_size,
                       "shift_size":self.shift_size,
                       "mlp_ratio":self.mlp_ratio,
                       "qkv_bias":self.qkv_bias,
                       "qk_scale":self.qk_scale,
                       "drop_rate":self.drop_rate,
                       "attn_drop_rate":self.attn_drop_rate,
                       "drop_path_prob":self.drop_path_prob,
                       })
        return config
    
    def call(self, x):
        H, W = self.input_resolution
        _, L, C = x.shape.as_list()
        assert L == H * W, f'input feature has wrong size, L={L}, H*W={H}*{W}'

        shortcut = x
        x = self.norm1(x)
        x = tf.reshape(x, shape=[-1, H, W, C])
 
        # cyclic shift
        if self.shift_size > 0:
            shifted_x = tf.roll(x,shift=[-self.shift_size,-self.shift_size],
                                axis=[1,2])
        else:
            shifted_x = x
        
        # partition windows: (B, H, W, C) -> (num_windows*B, ws*ws, C)
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = tf.reshape(x_windows,
                        shape=[-1, self.window_size * self.window_size, C])
        # W-MSA / SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
 
        # merge windows: (num_windows*B, ws*ws, C) -> (B, H, W, C)
        attn_windows = tf.reshape(attn_windows,
                                shape=[-1, self.window_size, self.window_size, C])
        shifted_x = window_reverse(attn_windows, self.window_size, H, W, C)
 
        # reverse cyclic shift
        if self.shift_size > 0:
            x = tf.roll(shifted_x,
                        shift=[self.shift_size,self.shift_size],
                        axis=[1,2])
        else:
            x = shifted_x
        
        x = tf.reshape(x,shape=[-1,H*W,C])
        
        # residual connections: attention branch, then MLP (FFN) branch
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
 
        return x
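
# Minimal shape sketch for a W-MSA / SW-MSA pair (illustration only; the
# commented demo at the bottom of this file shows the full pipeline):
#   x = tf.zeros([2, 56 * 56, 96])
#   x = SwinTransformerBlockLayer(96, [56, 56], num_heads=3, shift_size=0)(x)
#   x = SwinTransformerBlockLayer(96, [56, 56], num_heads=3, shift_size=3)(x)
#   # shapes are unchanged: (2, 3136, 96)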
 
class PatchMergingLayer(Layer):
    """Downsamples by 2x: concatenates each 2x2 neighborhood of tokens
    (C -> 4C), normalizes, then projects to 2C."""
    def __init__(self, input_resolution, dim, **kwargs):
        super(PatchMergingLayer, self).__init__(**kwargs)

        self.input_resolution = input_resolution
        self.dim = dim

        self.norm = LayerNormalization(epsilon=1e-5)
        self.reduction = Dense(2 * self.dim, use_bias=False)
 
 
    def get_config(self):
        config = super(PatchMergingLayer,self).get_config()
        config.update({"input_resolution":self.input_resolution,
                       "dim":self.dim})
 
        return config
 
    def call(self, x):
        H, W = self.input_resolution
        _, L, C = x.shape.as_list()
        assert L == H * W, 'input feature has wrong size'
        assert H % 2 == 0 and W % 2 == 0, f'x size ({H}*{W}) is not even.'
 
        x = tf.reshape(x,shape=[-1,H,W,C])
 
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = tf.concat([x0, x1, x2, x3], axis=-1)
        x = tf.reshape(x, shape=[-1, (H // 2) * (W // 2), 4 * C])
 
        x = self.norm(x)
        x = self.reduction(x)
 
        return x
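
# Shape sketch for PatchMergingLayer (illustration only):
#   x = tf.zeros([2, 56 * 56, 96])
#   y = PatchMergingLayer([56, 56], 96)(x)
#   # token count drops 4x, channels double: y.shape == (2, 28 * 28, 192)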
    
 
class PatchEmbeddingLayer(Layer):
    """Splits the image into non-overlapping patches and linearly embeds
    each patch via a strided convolution."""
    def __init__(self, img_size=[224, 224], patch_size=[4, 4],
                 embed_dims=96, **kwargs):
        super(PatchEmbeddingLayer, self).__init__(**kwargs)

        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dims = embed_dims

        patchs_resolution = [self.img_size[0] // self.patch_size[0],
                             self.img_size[1] // self.patch_size[1]]

        self.patchs_resolution = patchs_resolution
        self.num_patches = patchs_resolution[0] * patchs_resolution[1]

        # kernel size == stride == patch size: one embedding per patch
        self.proj = Conv2D(self.embed_dims, self.patch_size,
                           strides=self.patch_size)
        
 
    def get_config(self):
        config = super(PatchEmbeddingLayer, self).get_config()
        # patchs_resolution and num_patches are derived in __init__, so only
        # the constructor arguments are serialized.
        config.update({"img_size": self.img_size,
                       "patch_size": self.patch_size,
                       "embed_dims": self.embed_dims})
        return config
    
    def call(self,x):
        _,H,W,C = x.shape.as_list()
        assert H==self.img_size[0] and W==self.img_size[1], \
            f'input img size ({H}*{W}) does not match model ({self.img_size[0]}*{self.img_size[1]}).'
        
        x = self.proj(x)
        _,h,w,c = x.shape.as_list()
        x = tf.reshape(x,shape=[-1,h*w,c])
 
        return x
 
 
# if __name__ == '__main__':
#     input_shape = [512,512,1]
#     patch_size = 4
#     window_size = 8
#     if input_shape[0]%patch_size!=0:
#         raise ValueError("input size is not divisible by patch_size")
#     if (input_shape[0]//patch_size)%window_size!=0:
#         raise ValueError("patch grid is not divisible by window_size")

#     inputs = Input(shape=input_shape)
#     # patch embedding: 512/4 = 128 tokens per side
#     x = PatchEmbeddingLayer(img_size=input_shape[:-1],patch_size=[patch_size,patch_size])(inputs)
#     print(f'shape after patch_embedding (b,128*128,96): {x.shape}')

#     # a pair of Swin Transformer blocks: W-MSA then SW-MSA
#     # shift_size=0; num_heads=3; window_size=8; mlp_ratio=4
#     x = SwinTransformerBlockLayer(96,[input_shape[0]//patch_size,input_shape[1]//patch_size],
#                                 num_heads=3,
#                                 window_size=window_size,
#                                 shift_size=0,
#                                 mlp_ratio=4)(x)

#     print(f'shape after the unshifted STB (b,128*128,96): {x.shape}')
#     # shift_size=4; num_heads=3; window_size=8; mlp_ratio=4
#     x = SwinTransformerBlockLayer(96,[input_shape[0]//patch_size,input_shape[1]//patch_size],
#                                 num_heads=3,
#                                 window_size=window_size,
#                                 shift_size=window_size//2,
#                                 mlp_ratio=4)(x)
#     print(f'shape after the shifted STB (b,128*128,96): {x.shape}')

#     # patch_merging: h,w halve, channels double
#     x = PatchMergingLayer([input_shape[0]//patch_size,input_shape[1]//patch_size],96)(x)
#     print(f'shape after patch_merging (b,64*64,192): {x.shape}')